# -*- coding:utf-8 -*-


from config.load_config import load_config
from embed.load_embedding import build_embed
from dataset.build_dataset import build_dataset
from model.load_model import model_train, model_predict

"""
模型主流程
"""


def train(config_path):
    """
    Run the training flow end to end.

    The configuration is loaded first with ``is_train=True``; the remaining
    stages are then invoked in order, each without arguments: word-embedding
    setup, dataset pipeline construction, and model training.

    NOTE(review): the stages take no parameters, so they presumably read
    the state prepared by ``load_config`` -- confirm against those modules.

    :param config_path: passed unchanged to ``load_config``.
    :return: None
    """
    # The configuration must be loaded before any other stage runs.
    load_config(config_path, is_train=True)

    # Remaining stages, executed strictly in this order:
    # embeddings -> dataset pipeline -> training.
    for stage in (build_embed, build_dataset, model_train):
        stage()


def predict(config_path, x):
    """
    Run prediction for ``x``.

    Loads the configuration with ``is_train=False``, runs the word-embedding
    setup, and then hands ``x`` to ``model_predict``.

    :param config_path: passed unchanged to ``load_config``.
    :param x: input forwarded as-is to ``model_predict``.
    :return: whatever ``model_predict(x)`` returns.
    """
    # Configuration is loaded in non-training mode before anything else.
    load_config(config_path, is_train=False)
    # Word-embedding setup runs before the model is queried.
    build_embed()
    return model_predict(x)



